use super::{RawTask, RawTaskContext};
use crate::{Error, Result};
use core::cell::Cell;
use core::future::Future;
use core::marker::PhantomData;
use core::ops::{Deref, DerefMut};
use core::pin::Pin;
use core::ptr;
use core::sync::atomic::{
    AtomicPtr,
    Ordering::{AcqRel, Acquire, Relaxed, Release},
};
use core::task::{Context, Poll};

/// Asynchronous mutex.
///
/// Tasks trying to lock push themselves onto an intrusive stack whose head is
/// `waiting` and whose links are `RawTask::next`. The node at the bottom of
/// that stack is the task that owns the lock; releasing the lock hands it to
/// the oldest waiter (the node directly above the owner).
pub struct Mutex<T> {
    // Head (most recently pushed task) of the wait stack; null when nobody
    // holds or waits for the lock.
    waiting: AtomicPtr<RawTask>,
    // Task currently holding the lock, or null.
    // NOTE(review): this plain `Cell` is read and written by different tasks
    // (possibly on different threads, given the `Sync` impl) without atomic
    // synchronization; an `AtomicPtr` would make those accesses race-free.
    task: Cell<*mut RawTask>,
    // The protected value.
    // NOTE(review): `MutexGuard` mutates this through a pointer derived from
    // `&Mutex<T>`; it should live in an `UnsafeCell` for that to be permitted.
    val: T,
}

// SAFETY: the mutex owns its `T`, so moving the mutex to another thread moves
// the `T` with it, which requires `T: Send`.
unsafe impl<T: Send> Send for Mutex<T> {}
// SAFETY: a shared `&Mutex<T>` lets any thread obtain `&mut T` through a
// guard, so (like `std::sync::Mutex`) sharing requires `T: Send`, not `T: Sync`.
// NOTE(review): the internal `task` `Cell` is also reached through
// `&Mutex<T>` from several tasks without atomic synchronization.
unsafe impl<T: Send> Sync for Mutex<T> {}

impl<T> Mutex<T> {
    ///
    pub fn new(val: T) -> Self {
        Self {
            waiting: AtomicPtr::<RawTask>::new(ptr::null_mut::<RawTask>()),
            task: Cell::new(ptr::null_mut::<RawTask>()),
            val,
        }
    }
}

impl<T> Mutex<T> {
    /// 异步等待加锁成功，可能存在异步任务切换.
    pub async fn lock(&self) -> MutexGuard<'_, T> {
        LockFuture(self).await
    }

    /// 尝试加锁, 如果失败，则返回Err.
    pub async fn try_lock(&self) -> Result<MutexGuard<'_, T>> {
        TryLockFuture(self).await
    }
}

struct LockFuture<'a, T>(&'a Mutex<T>);

unsafe impl<T> Send for LockFuture<'_, T> {}

impl<'a, T> Future for LockFuture<'a, T> {
    type Output = MutexGuard<'a, T>;
    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let task = ctx.task();
        if ptr::eq(self.0.task.get(), task) {
            return Poll::Ready(MutexGuard::new(self.get_mut().0));
        }

        let mut next = self.0.waiting.load(Acquire);
        loop {
            task.next = next;
            match self
                .0
                .waiting
                .compare_exchange_weak(next, task, Release, Acquire)
            {
                Ok(_) => break,
                Err(old) => next = old,
            }
        }
        if next.is_null() {
            self.0.task.set(task);
            Poll::Ready(MutexGuard::new(self.get_mut().0))
        } else {
            Poll::Pending
        }
    }
}

struct TryLockFuture<'a, T>(&'a Mutex<T>);

unsafe impl<T> Send for TryLockFuture<'_, T> {}

impl<'a, T> Future for TryLockFuture<'a, T> {
    type Output = Result<MutexGuard<'a, T>>;
    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let task = ctx.task();
        let next = ptr::null_mut::<RawTask>();
        task.next = next;
        if self
            .0
            .waiting
            .compare_exchange(next, task, Release, Acquire)
            .is_ok()
        {
            self.0.task.set(task);
            return Poll::Ready(Ok(MutexGuard::new(self.get_mut().0)));
        }
        Poll::Ready(Err(Error::default()))
    }
}

/// Access to the value protected by a [`Mutex`], returned by [`Mutex::lock`]
/// and [`Mutex::try_lock`]. Dropping the guard releases the lock.
#[must_use = "if unused the Mutex will immediately unlock"]
pub struct MutexGuard<'a, T> {
    /// Raw pointer to the locked mutex, derived from the `&'a Mutex<T>` given
    /// to `new`.
    mutex: *mut Mutex<T>,
    /// Ties the guard to the `'a` borrow of the mutex it was created from.
    mark: PhantomData<&'a mut Mutex<T>>,
}

// SAFETY: moving the guard to another thread moves `&mut T` access there,
// which requires `T: Send`. The impl is needed because of the raw pointer field.
unsafe impl<T: Send> Send for MutexGuard<'_, T> {}

impl<'a, T> MutexGuard<'a, T> {
    /// Wraps a mutex that the current task has just acquired.
    fn new(mutex: &'a Mutex<T>) -> Self {
        Self {
            mutex: mutex as *const Mutex<T> as *mut Mutex<T>,
            mark: PhantomData,
        }
    }
}

impl<T> Deref for MutexGuard<'_, T> {
    type Target = T;
    fn deref(&self) -> &Self::Target {
        unsafe { &(*self.mutex).val }
    }
}

impl<T> DerefMut for MutexGuard<'_, T> {
    /// Borrows the protected value mutably; exclusive access relies on this
    /// task holding the lock.
    fn deref_mut(&mut self) -> &mut T {
        let raw = self.mutex;
        // SAFETY: `raw` points at the mutex this guard was created from, and
        // the `&mut self` receiver lets this guard hand out one borrow at a time.
        unsafe { &mut (*raw).val }
    }
}

impl<T> Drop for MutexGuard<'_, T> {
    fn drop(&mut self) {
        let mutex = unsafe { &*self.mutex };
        let mut next = mutex.waiting.load(Acquire);
        let task = mutex.task.replace(ptr::null_mut::<RawTask>());
        if ptr::eq(next, task) {
            match mutex
                .waiting
                .compare_exchange(task, ptr::null_mut::<RawTask>(), Relaxed, Acquire)
            {
                Ok(_) => return,
                Err(task) => next = task,
            }
        }
        while !ptr::eq(unsafe { (&*next).next }, task) {
            next = unsafe { (&*next).next };
        }
        unsafe { (&mut *next).next = ptr::null_mut() };

        mutex.task.set(next);
        unsafe { &*next }.wake();
    }
}

#[cfg(test)]
mod test {
    use crate::runtime::*;

    // Three tasks share one `Mutex<i32>` starting at 1. Each increments it twice
    // (taking and dropping the lock per increment), then returns the value it
    // reads under the lock. The last lock acquisition overall is some task's
    // final read, which comes after all six increments, so the maximum is 7.
    #[test]
    fn test_mutex() {
        // NOTE(review): the built value is discarded immediately — confirm this
        // is what sets up the runtime used by `spawn` below.
        let _ = Builder::new().build();
        let mutex = Mutex::new(1);
        // Re-borrow through a raw pointer to detach the reference's lifetime
        // from the local so it can be moved into spawned tasks (presumably
        // `spawn` needs `'static` futures — confirm). All three tasks are joined
        // below before `mutex` goes out of scope.
        let mutex = unsafe { &*(&mutex as *const Mutex<i32>) };

        // Increments the counter twice, then returns the value seen under the lock.
        async fn foo(mutex: &Mutex<i32>) -> i32 {
            for _ in 0..2 {
                let mut guard = mutex.lock().await;
                *guard += 1
            }
            let guard = mutex.lock().await;
            *guard
        }

        let h1 = spawn(foo(mutex));
        let h2 = spawn(foo(mutex));
        let h3 = spawn(foo(mutex));
        // Wait for every task and keep the largest value any of them read.
        let h = h1
            .join()
            .unwrap()
            .max(h2.join().unwrap())
            .max(h3.join().unwrap());
        assert_eq!(h, 7);
    }
}
